// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3

// include files
#include "natural.h"
#include "ntt.h"
#include <stdexcept>
// size of a container as a signed int, for comparison against int indices
// the container is taken by const reference so that calling this in a loop
// condition does not copy the whole container on every iteration
template<typename T> int isize(const T &t) {return static_cast<int>(t.size());}

namespace math
{
// addition
// adds rhs into *this limb by limb, lowest limb first, and returns *this
// the running sum is kept in a double width accumulator whose upper half
// is the carry into the next limb
N &N::operator+=(const N &rhs)
{
	// number of bits to shift the accumulator to leave only the carry
	constexpr int shift = sizeof(atom2)/2*8;

	// pad with zero limbs so that every limb of rhs has a counterpart
	const auto rhsLimbs = rhs.value.size();
	if (value.size() < rhsLimbs)
		value.insert(value.end(), rhsLimbs-value.size(), 0);

	// add the limbs that both operands have
	atom2 acc = 0;
	int pos = 0;
	for (; pos < isize(rhs.value); ++pos)
	{
		acc += value[pos];
		acc += rhs.value[pos];
		value[pos] = acc;
		acc >>= shift;
	}

	// ripple any remaining carry through the higher limbs of *this
	for (; acc != 0 && pos < isize(value); ++pos)
	{
		acc += value[pos];
		value[pos] = acc;
		acc >>= shift;
	}

	// a carry out of the top limb becomes a new limb
	if (acc != 0)
		value.insert(value.begin()+pos, acc);
	return *this;
}

// subtraction
// subtracts rhs from *this limb by limb, lowest limb first, and returns *this
// computed as *this + ~rhs + 1, so the carry starts at 1; a carry of 1 out
// of a limb means no borrow is needed, a carry of 0 means a borrow is pending
// throws std::underflow_error if rhs has more limbs than *this, or if a
// borrow is still pending after the top limb; in the second case the limbs
// of *this have already been overwritten
N &N::operator-=(const N &rhs)
{
	// a longer rhs is treated as larger than *this
	if (rhs.value.size() > value.size())
		throw std::underflow_error("result less than zero");

	// the initial carry supplies the +1 of the two's complement
	atom2 carry = 1;
	int i = 0;
	while (i < isize(rhs.value))
	{
		carry += value[i]; carry += atom(~rhs.value[i]);
		value[i++] = carry;
		carry >>= sizeof(carry)/2*8;
	}

	// propagate a pending borrow through the higher limbs of *this
	// rhs has no limb here, so its complement is all ones
	while (!carry && i < isize(value))
	{
		carry += value[i]; carry += atom(~0);
		value[i++] = carry;
		carry >>= sizeof(carry)/2*8;
	}

	// a borrow out of the top limb means rhs was larger than *this
	if (!carry)
		throw std::underflow_error("result less than zero");
	// TODO remove high order zeros
	return *this;
}

// multiplication
// replaces *this with the product *this * rhs and returns *this
// the limbs of both operands are convolved three times, each time modulo a
// different modulus of the form order*scale+1, and the exact coefficients
// are then rebuilt with the chinese remainder theorem while the carries
// are propagated into the result limbs
// NOTE(review): the arithmetic assumes Base is 32 bits and Base2 is 64 bits
// (see the >>32 and the 64 bit literal below) - confirm against ntt.h
N &N::operator*=(const N &rhs)
{
	constexpr Base order = 1u<<27, scale[] = {17, 24, 26};
	constexpr Base modulus[] = {order*scale[0]+1, order*scale[1]+1, order*scale[2]+1};

	// size of result
	// nr is the number of result limbs, n is nr rounded up to a power of two
	int n;
	int nr = value.size() + rhs.value.size();
	for (n = 1; n < nr; n *= 2);

	// convolve
	// each pass loads the limbs of both operands into buffers of length n,
	// reduced for that modulus, and leaves the result in xc[k]
	constexpr Base root[] = {0x07B285C3, 0x0DB2957C, 0x53027D11};
	Mpint xc[] = {Mpint(n), Mpint(n), Mpint(n)};
	Mpint yc;
	// x-p when x >= p, otherwise x (detected through unsigned wrap around)
	// only one subtraction, so the result is below p only when x < 2p
	auto qmod = [](Base x, Base p) -> Base {Base t = x; return Base{t-p} > t ? t : t-p;};
	for (int i = 0; i < isize(value); ++i) xc[0][i] = qmod(value[i], modulus[0]);
	yc = Mpint(n);
	for (int i = 0; i < isize(rhs.value); ++i) yc[i] = qmod(rhs.value[i], modulus[0]);
	convolve<order, scale[0], root[0]>(xc[0], yc);
	for (int i = 0; i < isize(value); ++i) xc[1][i] = qmod(value[i], modulus[1]);
	yc = Mpint(n);
	for (int i = 0; i < isize(rhs.value); ++i) yc[i] = qmod(rhs.value[i], modulus[1]);
	convolve<order, scale[1], root[1]>(xc[1], yc);
	for (int i = 0; i < isize(value); ++i) xc[2][i] = qmod(value[i], modulus[2]);
	yc = Mpint(n);
	for (int i = 0; i < isize(rhs.value); ++i) yc[i] = qmod(rhs.value[i], modulus[2]);
	convolve<order, scale[2], root[2]>(xc[2], yc);

	// chinese remainder to recover results
	// Base3 holds a three word quantity: l is the low word, h the rest
	struct Base3
	{
		Base l;
		Base2 h;
	};
	// product of the three moduli
	Base3 m = {0x18000001u, 0x52e0000170800002ul};
	// the operands are no longer needed, so value can take the result
	value.resize(nr);
	// running total; its high part is the carry into the next limb
	Base3 p = {0, 0};
	for (int i = 0; i < nr; ++i)
	{
		// (∏m/m)^(m-2) * a mod m * ∏m/m
		// m chosen so that these cannot overflow 96 bits
		// split to avoid 96 bit arithmetic
		// x0 collects the terms for the low word of each ∏m/m and x1 the
		// terms for its high word; the multipliers of x0 and x1 are the
		// low and high words of the product of the other two moduli
		Base2 x0 = 0, x1 = 0, xt;
		xt = Base2{xc[0][i]}*0x42EBAEC0u % modulus[0];
		x0 += Base2{xt}*0x90000001; x1 += Base2{xt}*0x9c000001;
		xt = Base2{xc[1][i]}*0x36DB6D8Eu % modulus[1];
		x0 += Base2{xt}*0x58000001; x1 += Base2{xt}*0x6e800001;
		xt = Base2{xc[2][i]}*0x2E38E3B4u % modulus[2];
		x0 += Base2{xt}*0x48000001; x1 += Base2{xt}*0x66000001;
		// add the carry from the previous limb, then fold x0 into x1
		x0 += p.h; x1 += x0>>32;
		p.h = x1; p.l = x0;

		// mod
		// subtract m at most twice, each time only if it does not underflow
		// the inner expression is the borrow out of the low word
		Base3 pt;
		pt = {p.l-m.l, p.h-m.h - (p.l-m.l>p.l?1:0)};
		if (pt.h <= p.h)
		{
			p = pt;
			pt = {p.l-m.l, p.h-m.h - (p.l-m.l>p.l?1:0)};
			if (pt.h <= p.h)
				p = pt;
		}

		// accumulate
		// the low word is the finished limb, p.h carries into the next one
		value[i] = p.l;
	}

	// remove top entry if 0
	// guarded so that a result with no limbs at all is not indexed
	if (value.size() > 0 && value[value.size()-1] == 0)
		value.erase(value.end()-1);

	// result
	return *this;
}
}
